import numpy as np
import pandas as pd
import random
import datetime
import ase.data as data
from pathlib import Path
from ase.io import write
import os
import pickle

import math

from gymnasium.spaces import Box, Dict

from ray.rllib.env.multi_agent_env import MultiAgentEnv

from .world import World

from UnitCell_Environment.unitcell_environment.env.observation import CombinedFlatObserver
from UnitCell_Environment.unitcell_environment.env.utils import ELEMENTS, COMP

from ray.rllib.utils.typing import (
    MultiAgentDict,
)

class HierParalUnitCellEnvironment(MultiAgentEnv):
    """Multi-agent RLlib environment in which step-size agents act in parallel.

    Agent names have the form ``"<STEPSIZE_L><element>_<index>"``: one name is
    created for each index ``i < COMP[comp][element][0]`` of each element in
    ``ELEMENTS[comp]``. On every :meth:`step` all step-size agents submit an
    action together, the shared :class:`World` is advanced once, and each agent
    is rewarded from the change of its own gradient norm (see ``reward_type``).

    Only the step-size level is implemented; actions for accept-level agents
    (``ACCEPT_L`` prefix) raise :class:`NotImplementedError`.
    """

    # Agent-name prefixes identifying the hierarchy level of an agent.
    STEPSIZE_L = "2_stepsize_"
    ACCEPT_L = "3_accept_"

    prefixes = [STEPSIZE_L, ACCEPT_L]

    def __init__(self, config):
        """Create the world, the agents and their observation/action spaces.

        Args:
            config: Mapping of environment options (``env_name``, ``debug``,
                ``cif``, ``fmax``, ``gamma``, ``step_cost``, ``comp``,
                ``reward_type``, ``stepsize_min``/``stepsize_max``, ...).
                It is also passed unchanged to :class:`World`.
        """
        super().__init__()

        self.env_name = config.get("env_name", "")
        self.debug = config.get("debug", False)
        self.cif = config.get("cif", None)
        self.fmax_limit = config.get("fmax", 0.05)
        self.gamma = config.get("gamma", 0.99)
        self.step_cost = config.get("step_cost", -0.1)
        # Floor applied to gradient norms before taking log2 in the reward.
        self.min_gnorm = config.get("min_gnorm", 0.0001)
        self.use_nstep_feature = config.get("use_nstep_feature", False)
        self.use_gnorm_feature = config.get("use_gnorm_feature", False)
        self.use_comp_feature = config.get("use_comp_feature", False)
        self.variable_step_size = config.get("variable_step_size", None)
        self.store_trajectory = config.get("store_trajectory", False)
        self.output_dir = config.get("output_dir", None)
        # Step-size coefficients applied by unscale_actions (logged per step).
        self.coefs = []

        self.world = World(config)
        self.max_energy = self.world.energy
        self._min_ac_energy = self.world.energy
        self.prev_energy = self.world.energy

        self.rej_num = 0
        self.zero_rew_num = 0

        self.max_episodes = config.get("max_episodes", 1)
        self._episode_n = 0
        self.optimisation_step = 0
        self.last_optimisation_len = 0
        self.discounted_reward = 0

        self.min_stepsize = config.get("stepsize_min", 0)
        self.max_stepsize = config.get("stepsize_max", 1)

        self.reward_type = config.get("reward_type", "log-grad-drop")

        self._obs_space_in_preferred_format = True
        self._action_space_in_preferred_format = True

        self._min_rel_energy = 0

        self.energy_plot_list = {"Step": [], "Energy": [], "Time": [], "Fmax": []}
        self.traj = {'atom_positions': [], 'cell': [], 'atomic_number': []}

        self.config = config

        # Active hierarchy levels; only the step-size level is used here.
        self.levels = [self.STEPSIZE_L]

        comp_list = config.get("comp", "SrTiO3x8").split(",")
        self.create_possible_agents(comp_list)

        self.action_space = Dict({})
        self.observation_space = Dict({})

        for agent in self.possible_agents:
            low_v = [-self.max_stepsize] * 3
            high_v = [self.max_stepsize] * 3

            if self.variable_step_size == "learn":
                # Fourth component: learned step-size coefficient in [-1, 1].
                low_v.append(-1)
                high_v.append(1)

            self.action_space[agent] = Box(np.array(low_v),
                                           np.array(high_v),
                                           shape=(len(low_v), ),
                                           dtype=np.float32)

        agents_dict = {agent: self.remove_prefix(agent) for agent in self.possible_agents}

        self.motif_observer = CombinedFlatObserver(self.world, agents_dict)

        self.agent_name_mapping = {
            agent: index for index, agent in enumerate(self.possible_agents)
        }

        # Must exist before added_env_obs(), which may read it.
        self.num_moves = 0

        env_obs_len = len(self.added_env_obs())
        for agent in self.possible_agents:
            self.observation_space[agent] = self.motif_observer.observation_space(env_obs_len)

        self.episode_start_time = datetime.datetime.now()

        self.initialize_learning()

    def update_observations(self):
        """Refresh ``self._observations`` for every currently active agent."""
        env_obs = self.added_env_obs()
        self.motif_observer.update_observations(env_obs)

        for agent in self.agents:
            self._observations[agent] = self.motif_observer.observe(agent)

    def added_env_obs(self):
        """Return extra environment features appended to every observation.

        Returns:
            A list containing ``num_moves`` when ``use_nstep_feature`` is set,
            followed, when ``use_comp_feature`` is set, by the covalent radii of
            the composition's elements, zero-padded to length 5 and sorted
            ascending (so the padding zeros come first).
        """
        env_obs = []

        if self.use_nstep_feature:
            env_obs.append(self.num_moves)

        if self.use_comp_feature:
            radii = [
                data.covalent_radii[data.atomic_numbers[el]]
                for el in ELEMENTS[self.world.comp_name]
            ]
            radii.extend([0] * (5 - len(radii)))
            radii.sort()
            env_obs.extend(radii)

        return env_obs

    def observe(self, agent):
        """Return the latest observation of ``agent``."""
        return self._observations[agent]

    def reset(self, seed=None, cif=None, return_info=False, options=None):
        """Save the finished episode's outputs and start a new episode.

        Episodes are grouped into optimisations: the episode counter,
        ``optimisation_step`` and ``coefs`` restart when the force criterion
        (``fmax < fmax_limit``) is met or ``max_episodes`` is reached.

        Args:
            seed: Forwarded to ``World.initialize``.
            cif: Structure file for the new episode; stored in ``self.cif``
                and forwarded to ``World.initialize``.
            return_info: Unused; kept for API compatibility.
            options: Unused; kept for API compatibility.

        Returns:
            Tuple ``(observations, infos)``, both keyed by agent name.
        """
        self._obs_space_in_preferred_format = True
        self._action_space_in_preferred_format = True

        self._save_episode_outputs()

        self.cif = cif  # update cif with new cif name

        self.energy_plot_list = {"Step": [], "Energy": [], "Time": [], "Fmax": []}
        self.episode_start_time = datetime.datetime.now()

        # Atoms are only re-read at the first episode of an optimisation.
        self.world.initialize(seed=seed, cif=self.cif, update_atoms=self.episode_n == 0)

        if self.fmax < self.fmax_limit or self._episode_n >= self.max_episodes - 1:
            self.last_optimisation_len = self.optimisation_step
            self.optimisation_step = 0
            self._episode_n = 0
            self.coefs = []
        else:
            self._episode_n = self._episode_n + 1

        self.max_energy = self.world.energy
        self._min_ac_energy = self.world.energy
        self._min_rel_energy = self.world.energy
        self.prev_energy = self.world.energy
        self.discounted_reward = 0

        self.world.save_checkpoint()

        if self.store_trajectory:
            self.reset_history()

        self.initialize_learning()

        return self._observations, self.infos

    def _save_episode_outputs(self):
        """Write the energy log, final structure and trajectory of the last episode.

        Does nothing unless ``output_dir`` is set and more than one step was
        logged in ``energy_plot_list``.
        """
        if len(self.energy_plot_list["Step"]) <= 1 or self.output_dir is None:
            return

        df_plot = pd.DataFrame(self.energy_plot_list)

        Path(self.output_dir).mkdir(parents=True, exist_ok=True)

        # Append to the CSV; write the header only when the file is empty.
        with open(f"{self.output_dir}/opt_traj.csv", 'a') as f:
            df_plot.to_csv(f, mode='a', index=False, header=f.tell() == 0)

        final_str_dir = f"{self.output_dir}/final_structures"
        Path(final_str_dir).mkdir(parents=True, exist_ok=True)

        cif_name = os.path.basename(self.cif) if self.cif is not None else "final_structure.cif"
        if os.path.exists(f"{final_str_dir}/{cif_name}"):
            # Do not overwrite an earlier result: append a timestamp.
            cif_name = Path(cif_name).stem + datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + ".cif"

        write(f"{final_str_dir}/{cif_name}", self.world.atoms)
        # Use a context manager so the pickle is flushed and the handle closed.
        with open(f"{final_str_dir}/{Path(cif_name).stem + '.pkl'}", "wb") as f:
            pickle.dump(self.traj, f)

        if self.store_trajectory:
            with open(f"{self.output_dir}/history_{self.world.comp_name}_{self.env_name}.csv", 'a') as f:
                self.history().to_csv(f, mode='a', index=False, header=f.tell() == 0)

    def initialize_learning(self):
        """Reset all per-episode learning state except the world itself.

        If the previous episode made at least one move, the horizon
        ``max_cycles`` is the configured value; otherwise (e.g. on
        construction) a random horizon in ``[1, max_cycles]`` is drawn.
        Agents, rewards, termination/truncation flags, infos, the energy log,
        the trajectory and the observations are all re-initialised.
        """
        configured_cycles = self.config.get("max_cycles", 100)
        if self.num_moves > 0:
            self.max_cycles = configured_cycles
        else:
            self.max_cycles = random.randint(1, configured_cycles)

        self.rej_num = 0
        self.zero_rew_num = 0
        self.num_moves = 0

        self.initialize_agents(self.world.comp_name)
        self.rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.terminations["__all__"] = False
        self.truncations["__all__"] = False
        self.infos = {agent: {} for agent in self.agents}

        self.motif_observer.agents = {
            agent: self.remove_prefix(agent) for agent in self.agents
        }

        self.energy_plot_list["Step"].append(0)
        self.energy_plot_list["Energy"].append(self.world.energy)
        self.energy_plot_list["Time"].append((datetime.datetime.now() - self.episode_start_time).total_seconds())
        self.energy_plot_list["Fmax"].append(self.world.fmax)

        self.traj = {'atom_positions': [self.world.atoms.positions.tolist()],
                     'cell': [self.world.cell.tolist()],
                     'atomic_number': [self.world.atoms.get_atomic_numbers().tolist()]}

        self._observations = {}
        self.update_observations()

    def step(self, actions, sim_data=None):
        """Apply one joint action of all step-size agents.

        Args:
            actions: Mapping ``agent name -> action``. An empty mapping clears
                ``self.agents`` and returns five empty dicts.
            sim_data: Optional data passed to ``World.update_from_data``
                instead of applying ``actions`` through ``World.take_action``.

        Returns:
            Tuple ``(observations, rewards, terminations, truncations, infos)``.

        Raises:
            NotImplementedError: If the actions belong to accept-level agents.
            ValueError: If the first agent name has no known level prefix.
        """
        if not actions:
            self.agents = []
            return {}, {}, {}, {}, {}

        self.rewards = {agent: 0 for agent in self.agents}
        # The level of the whole joint action is decided by its first agent.
        first_agent = next(iter(actions))

        if self.is_step_agent(first_agent):
            # Shallow copy so unscale_actions does not mutate the caller's dict.
            flat_actions = self.unscale_actions(dict(actions))
        elif self.is_accept_agent(first_agent):
            raise NotImplementedError(
                "Accept level agent in parallel environment is not implemented yet")
        else:
            raise ValueError(f"Invalid agent name: {first_agent}")

        if self.ACCEPT_L in self.levels:
            self.step_stepsize(flat_actions)
        else:
            self.step_stepsize(flat_actions, sim_data=sim_data)
            self.post_process_step(accept=1)

        self.print_debug(f"Time for energy calculation: {self.world.time}")

        self.terminations["__all__"] = all(self.terminations.values())
        self.truncations["__all__"] = all(self.truncations.values())

        return self._observations, self.rewards, self.terminations, self.truncations, self.infos

    def post_process_step(self, accept=1):
        """Advance counters, update terminations and the discounted return.

        Every agent terminates once ``num_moves`` reaches ``max_cycles`` or the
        world's ``fmax`` drops below ``fmax_limit``.

        Args:
            accept: Unused; kept for API compatibility.
        """
        self.rej_num = 0
        self.optimisation_step += 1
        self.num_moves += 1

        # Cast to a plain bool: fmax may be a numpy scalar, whose comparison
        # yields numpy.bool_ (which fails identity checks against True).
        done = bool(self.num_moves >= self.max_cycles or self.fmax < self.fmax_limit)
        self.terminations = {agent: done for agent in self.agents}

        if all(self.terminations.values()):
            # NOTE(review): `stays` is not read anywhere in this class.
            self.stays = {agent: False for agent in self.agents}

        self.discounted_reward += sum(self.rewards.values()) * math.pow(self.gamma, self.num_moves)

    def step_stepsize(self, actions, sim_data=None):
        """Apply step-size actions to the world and compute per-agent rewards.

        Args:
            actions: Mapping ``prefixed agent name -> (already unscaled) action``.
            sim_data: If not ``None``, the world is updated from this data via
                ``World.update_from_data`` instead of ``World.take_action``.
        """
        old_obs = {agent: self._observations[agent].copy() for agent in self.agents}

        self.print_debug(f"Agents take actions {actions}")

        self.prev_energy = self.world.energy
        prev_fmean = self.world.fmean

        agent_actions = {self.remove_prefix(item): actions[item] for item in actions}

        # `is None` rather than `== None`: sim_data may be array-like.
        if sim_data is None:
            self.world.take_action(1, agent_actions)
        else:
            print("Update world from data")
            self.world.update_from_data(sim_data)

        trial_energy = self.world.energy
        if self.prev_energy == trial_energy:
            self.print_debug(f"Energy did not change: {self.prev_energy} {trial_energy}")

        self.rewards = {agent: 0 for agent in self.agents}

        def floored_log2(value):
            # Floor at min_gnorm so log2 never receives zero.
            return math.log2(max(value, self.min_gnorm))

        for a in self.agents:
            agent = self.world.agents[self.remove_prefix(a)]

            if self.reward_type == "log-grad-drop":
                self.rewards[a] = floored_log2(agent.prev_gnorm) - floored_log2(agent.gnorm)

            elif self.reward_type == "log-grad-drop-plus-average":
                self.rewards[a] = floored_log2(agent.prev_gnorm) \
                                - floored_log2(agent.gnorm) \
                                + floored_log2(prev_fmean) \
                                - floored_log2(self.world.fmean)

            elif self.reward_type == "log-grad-drop-plus-const":
                self.rewards[a] = floored_log2(agent.prev_gnorm) \
                                - floored_log2(agent.gnorm) + self.step_cost
            # Any other reward_type leaves the reward at 0.

        self.update_observations()
        trial_obs = {agent: self._observations[agent].copy() for agent in self.agents}

        if self.store_trajectory:
            for item in actions:
                self._history[f"row_{self._history_len + 1}"] = [self.remove_prefix(item),
                                                                 self.num_moves,
                                                                 1,
                                                                 list(actions[item]),
                                                                 list(old_obs[item]),
                                                                 list(trial_obs[item]),
                                                                 self.rewards[item],
                                                                 trial_energy,
                                                                 self.prev_energy,
                                                                 "A"]
                self._history_len += 1

        if self.output_dir is not None:
            self.energy_plot_list["Step"].append(self.num_moves + 1)
            self.energy_plot_list["Energy"].append(trial_energy)
            self.energy_plot_list["Time"].append((datetime.datetime.now() - self.episode_start_time).total_seconds())
            self.energy_plot_list["Fmax"].append(self.world.fmax)

            self.traj['atom_positions'].append(self.world.atoms.positions.tolist())
            self.traj["cell"].append(self.world.cell.tolist())
            self.traj["atomic_number"].append(self.world.atoms.get_atomic_numbers().tolist())

    def print_debug(self, text):
        """Print ``text`` only when the environment runs in debug mode."""
        if self.debug:
            print(text)

    def history(self):
        """Return the recorded step history as a DataFrame, or ``None``.

        ``None`` is returned when ``store_trajectory`` is disabled.
        """
        if not self.store_trajectory:
            return None

        return pd.DataFrame.from_dict(self._history, orient='index',
                                      columns=['agent',
                                               'step',
                                               'action',
                                               'action_parameter',
                                               'old_observation',
                                               'trial_observation',
                                               'reward',
                                               'trial_energy',
                                               'old_energy',
                                               'result'])

    def reset_history(self):
        """Clear the recorded step history (no-op unless ``store_trajectory``)."""
        if not self.store_trajectory:
            return

        self._history = {}
        self._history_len = 0

    def save_cif(self, filename):
        """Write the current structure to ``filename`` and the world's
        ``min_atoms`` structure to ``<stem>_min.cif``, both in CIF format."""
        self.world.atoms.write(filename, format='cif')
        min_filename = filename[:-4] + "_min.cif" if filename.endswith(".cif") else filename + "_min.cif"
        self.world.min_atoms.write(min_filename, format="cif")

    @staticmethod
    def remove_prefix(agent):
        """Strip the known level prefixes (in ``prefixes`` order) from ``agent``."""
        res = agent

        for prefix in HierParalUnitCellEnvironment.prefixes:
            res = res.removeprefix(prefix)

        return res

    @staticmethod
    def prefix(agent):
        """Rejoin ``agent`` split at its first underscore.

        NOTE(review): as written this returns ``agent`` unchanged (and raises
        IndexError if ``agent`` has no underscore). The name suggests it was
        meant to return the level prefix; confirm before relying on it.
        """
        return agent.split("_", 1)[0] + "_" + agent.split("_", 1)[1]

    def agent_type(self, agent):
        """Return the level prefix of ``agent``, or ``None`` if it has none."""
        if agent.startswith(self.STEPSIZE_L):
            return self.STEPSIZE_L
        if agent.startswith(self.ACCEPT_L):
            return self.ACCEPT_L
        return None

    def is_step_agent(self, agent):
        """Return whether ``agent`` belongs to the step-size level."""
        return agent.startswith(self.STEPSIZE_L)

    def is_accept_agent(self, agent):
        """Return whether ``agent`` belongs to the accept level."""
        return agent.startswith(self.ACCEPT_L)

    def observation_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
        """Sample an observation for each id in ``agent_ids`` (default: active agents)."""
        if agent_ids is None:
            agent_ids = self.agents

        return {agent_id: self.observation_space[agent_id].sample() for agent_id in agent_ids}

    def action_space_sample(self, agent_ids: list = None) -> MultiAgentDict:
        """Sample an action for each id in ``agent_ids`` (default: active agents)."""
        if agent_ids is None:
            agent_ids = self.agents

        return {agent_id: self.action_space[agent_id].sample() for agent_id in agent_ids}

    def action_space_contains(self, x: MultiAgentDict) -> bool:
        """Return whether every action in dict ``x`` lies in its agent's space."""
        if not isinstance(x, dict):
            return False

        return all(self.action_space[key].contains(x[key]) for key in x)

    def observation_space_contains(self, x: MultiAgentDict) -> bool:
        """Return whether every observation in dict ``x`` lies in its agent's space."""
        if not isinstance(x, dict):
            return False

        return all(self.observation_space[key].contains(x[key]) for key in x)

    def unscale_actions(self, actions_dict):
        """Convert raw policy outputs into the step vectors given to the world.

        Modes (``variable_step_size``):
            ``"learn"``: the 4th action component is a coefficient that
                multiplies the first three; it is also appended to ``coefs``.
            ``"log_gnorm"`` / ``"gnorm"``: the action is multiplied by
                ``World.variable_step_size_coef`` for that agent.
            anything else: if any component lies outside
                ``[-max_stepsize, max_stepsize]``, all actions are rescaled by
                the same factor so the largest magnitude equals ``max_stepsize``.

        Args:
            actions_dict: Mapping ``agent name -> action``. Modified in place.
                Values are presumably numpy arrays (slicing and scalar
                multiplication are used) — TODO confirm.

        Returns:
            The same (mutated) mapping.
        """
        if self.variable_step_size == "learn":
            for key in actions_dict:
                if self.is_step_agent(key):
                    self.coefs.append(actions_dict[key][3])
                    actions_dict[key] = actions_dict[key][:3] * actions_dict[key][3]

        elif self.variable_step_size in ["log_gnorm", "gnorm"]:
            for key in actions_dict:
                if self.is_step_agent(key):
                    coef = self.world.variable_step_size_coef(self.world.agents[self.remove_prefix(key)])
                    self.coefs.append(coef)
                    actions_dict[key] = actions_dict[key] * coef

        else:
            # Rescale uniformly into (-max_stepsize, max_stepsize) if out of bounds.
            min_value = min(min(values) for values in actions_dict.values())
            max_value = max(max(values) for values in actions_dict.values())

            max_abs_value = max([-min_value, max_value, self.max_stepsize])

            if max_abs_value > self.max_stepsize:
                for key in actions_dict:
                    actions_dict[key] = actions_dict[key] * self.max_stepsize / max_abs_value

        return actions_dict

    def create_possible_agents(self, comp_list):
        """Set ``possible_agents`` to the sorted union of existing and new names.

        Builds a new list instead of appending in place: before the first
        assignment ``self.possible_agents`` resolves to the base class's
        class-level list, and appending to it would leak agent names across
        environment instances.

        Args:
            comp_list: Composition names used as keys of ``ELEMENTS``/``COMP``.
        """
        names = set(getattr(self, "possible_agents", None) or [])

        for comp in comp_list:
            for el in ELEMENTS[comp]:
                for i in range(COMP[comp][el][0]):
                    names.add(f"{self.STEPSIZE_L}{el}_{i}")

        self.possible_agents = sorted(names)

    def initialize_agents(self, comp):
        """Set ``agents`` to the step-size agent names for composition ``comp``."""
        self.agents = [
            f"{self.STEPSIZE_L}{el}_{i}"
            for el in ELEMENTS[comp]
            for i in range(COMP[comp][el][0])
        ]

    @property
    def energy(self):
        """Current energy reported by the world."""
        return self.world.energy

    @property
    def init_energy(self):
        """``init_energy`` reported by the world."""
        return self.world.init_energy

    @property
    def init_relaxed_energy(self):
        """``init_relaxed_energy`` reported by the world."""
        return self.world.init_relaxed_energy

    @property
    def last_relaxed_energy(self):
        """``last_relaxed_energy`` reported by the world."""
        return self.world.last_relaxed_energy

    @property
    def test_energy(self):
        """``test_energy`` reported by the world."""
        return self.world.test_energy

    @property
    def min_ac_energy(self):
        """Stored ``_min_ac_energy`` (set to the world energy on reset)."""
        return self._min_ac_energy

    @property
    def min_energy(self):
        """``min_energy`` reported by the world."""
        return self.world.min_energy

    @property
    def min_rel_energy(self):
        """Stored ``_min_rel_energy`` (set to the world energy on reset)."""
        return self._min_rel_energy

    @property
    def episode_n(self):
        """Index of the current episode within the current optimisation."""
        return self._episode_n

    @property
    def fmax(self):
        """``fmax`` reported by the world (compared against ``fmax_limit``)."""
        return self.world.fmax

    @property
    def fmean(self):
        """``fmean`` reported by the world."""
        return self.world.fmean

    @property
    def e_calc_time(self):
        """``time`` reported by the world (logged as energy-calculation time)."""
        return self.world.time

    @property
    def observations(self):
        """Latest observations keyed by agent name."""
        return self._observations
